/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef IMPL_REDUCE_REDUCE_MAX_REDUCE_MAX_C310_IMPL_H_
#define IMPL_REDUCE_REDUCE_MAX_REDUCE_MAX_C310_IMPL_H_

#include "kernel_tensor.h"
#include "kernel_operator_intf.h"
#include "kernel_tiling/kernel_tiling.h"
#include "../../common/check.h"
#include "../../api_check/kernel_api_check.h"
#include "../reduce_common_util_impl.h"
#include "../reduce_common_util_c310_impl.h"
#include "../reduce_common_ar_ra_reuse_unalign_c310_impl.h"
#include "../reduce_common_ar_reuse_align_c310_impl.h"
#include "../reduce_common_ra_reuse_align_c310_impl.h"

namespace AscendC {
namespace Internal {
template <class T, bool isReuseSource = false>
__aicore__ inline void ReduceMaxARB64ReuseSourceCompute(__ubuf__ T *dstAddr, __ubuf__ T *srcAddr,
    __ubuf__ T *tmpAddr, const uint32_t srcShape[])
{
    if ((srcShape[1] * sizeof(T)) % 32 == 0) {
        ReduceARImpl<T, MicroAPI::RegTraitNumTwo,
            MicroAPI::Max<T, MicroAPI::MaskMergeMode::ZEROING,
            MicroAPI::RegTensor<T, MicroAPI::RegTraitNumTwo>>,
            MicroAPI::ReduceMax<T, MicroAPI::MaskMergeMode::ZEROING,
            MicroAPI::RegTensor<T, MicroAPI::RegTraitNumTwo>>, isReuseSource>(
                dstAddr, srcAddr, tmpAddr, srcShape[0], srcShape[1]);
    } else {
        ReduceARReuseSourceUnAligned<T, MicroAPI::RegTraitNumTwo,
            MicroAPI::Max<T, MicroAPI::MaskMergeMode::ZEROING,
            MicroAPI::RegTensor<T, MicroAPI::RegTraitNumTwo>>,
            MicroAPI::ReduceMax<T, MicroAPI::MaskMergeMode::ZEROING,
            MicroAPI::RegTensor<T, MicroAPI::RegTraitNumTwo>>>(
                dstAddr, srcAddr, srcShape[0], srcShape[1]);
    }
}

template <class T, bool isReuseSource = false>
__aicore__ inline void ReduceMaxARReuseSourceCompute(__ubuf__ T *dstAddr, __ubuf__ T *srcAddr,
    __ubuf__ T *tmpAddr, const uint32_t srcShape[])
{
    using castType = typename ReduceOpInternal::ExtractReduceCastType<T>::CastT;
    if ((srcShape[1] * sizeof(T)) % 32 == 0 || srcShape[1] == 1) {
        ReduceARImpl<T, MicroAPI::RegTraitNumOne,
            MicroAPI::Max<T, MicroAPI::MaskMergeMode::ZEROING, MicroAPI::RegTensor<T>>,
            MicroAPI::ReduceMax<castType, MicroAPI::MaskMergeMode::ZEROING, MicroAPI::RegTensor<castType>>,
            isReuseSource, ReduceType::IS_REDUCE_MAX>(
                dstAddr, srcAddr, tmpAddr, srcShape[0], srcShape[1]);
    } else {
        if constexpr (SupportBytes<T, 1>()) {
            constexpr uint32_t ONE_STRIDE = 32768;
            uint16_t repeatTimes = (srcShape[0] + ONE_STRIDE - 1) >> 15;
            for (uint16_t i = 0; i < repeatTimes; ++i) {
                uint32_t dimA =
                    ONE_STRIDE < srcShape[0] - ONE_STRIDE * i ? ONE_STRIDE : srcShape[0] - ONE_STRIDE * i;
                ReduceARReuseSourceUnAligned<T, MicroAPI::RegTraitNumOne,
                    MicroAPI::Max<T, MicroAPI::MaskMergeMode::ZEROING, MicroAPI::RegTensor<T>>,
                    MicroAPI::ReduceMax<castType, MicroAPI::MaskMergeMode::ZEROING, MicroAPI::RegTensor<castType>>>(
                        dstAddr + ONE_STRIDE * i, srcAddr + ONE_STRIDE * i * srcShape[1], dimA, srcShape[1]);
            }
        } else {
            ReduceARReuseSourceUnAligned<T, MicroAPI::RegTraitNumOne,
                MicroAPI::Max<T, MicroAPI::MaskMergeMode::ZEROING, MicroAPI::RegTensor<T>>,
                MicroAPI::ReduceMax<castType, MicroAPI::MaskMergeMode::ZEROING, MicroAPI::RegTensor<castType>>>(
                    dstAddr, srcAddr, srcShape[0], srcShape[1]);
        }
    }
}

template <class T, bool isReuseSource = false>
__aicore__ inline void ReduceMaxRAB64ReuseSourceCompute(__ubuf__ T *dstAddr, __ubuf__ T *srcAddr,
    __ubuf__ T *tmpAddr, const uint32_t srcShape[])
{
    if ((srcShape[1] * sizeof(T)) % 32 == 0) {
        ReduceRAB64ReuseSource<T, MicroAPI::RegTraitNumTwo,
            MicroAPI::Max<T, MicroAPI::MaskMergeMode::ZEROING,
            MicroAPI::RegTensor<T, MicroAPI::RegTraitNumTwo>>, isReuseSource>(
                dstAddr, srcAddr, tmpAddr, srcShape[1], srcShape[0]);
    } else {
        ReduceRAReuseSourceUnAlignedB64<T, MicroAPI::RegTraitNumTwo,
            MicroAPI::Max<T, MicroAPI::MaskMergeMode::ZEROING,
            MicroAPI::RegTensor<T, MicroAPI::RegTraitNumTwo>>>(
                dstAddr, srcAddr, srcShape[1], srcShape[0]);
    }
}

template <class T, bool isReuseSource = false>
__aicore__ inline void ReduceMaxRAReuseSourceCompute(__ubuf__ T *dstAddr, __ubuf__ T *srcAddr,
    __ubuf__ T *tmpAddr, const uint32_t srcShape[])
{
    if ((srcShape[1] * sizeof(T)) % 32 == 0) {
        ReduceRAImpl<T, MicroAPI::RegTraitNumOne,
            MicroAPI::Max<T, MicroAPI::MaskMergeMode::ZEROING,
            MicroAPI::RegTensor<T>>, isReuseSource>(
                dstAddr, srcAddr, tmpAddr, srcShape[1], srcShape[0]);
    } else {
        ReduceRAReuseSourceUnAligned<T, MicroAPI::RegTraitNumOne,
            MicroAPI::Max<T, MicroAPI::MaskMergeMode::ZEROING,
            MicroAPI::RegTensor<T>>>(
                dstAddr, srcAddr, srcShape[1], srcShape[0]);
    }
}

template <class T, class pattern, bool isReuseSource = false>
__aicore__ inline void ReduceMaxImpl(const LocalTensor<T> &dst, const LocalTensor<T> &src, const LocalTensor<uint8_t> &sharedTmpBuffer,
    const uint32_t srcShape[], bool srcInnerPad)
{
    CheckTensorPos<T>(dst, Hardware::UB, "dstTensor", "VECIN / VECCALC / VECOUT", "ReduceMax");
    CheckTensorPos<T>(src, Hardware::UB, "srcTensor", "VECIN / VECCALC / VECOUT", "ReduceMax");
    CheckTensorPos<uint8_t>(sharedTmpBuffer, Hardware::UB, "sharedTmpBuffer", "VECIN / VECCALC / VECOUT", "ReduceMax");
    static_assert(SupportType<T, int8_t, uint8_t, int16_t, uint16_t, half, bfloat16_t, int32_t, uint32_t, float,
        int64_t, uint64_t>(),
        "ReduceMax only support "
        "int8_t/uint8_t/int16_t/uint16_t/half/bfloat16_t/int32_t/uint32_t/float/int64_t/uint64_t data type on current "
        "device!");
    static_assert(std::is_same_v<pattern, Pattern::Reduce::AR> || std::is_same_v<pattern, Pattern::Reduce::RA>,
        "ReduceMax only support AR and RA pattern on current device!");

    __ubuf__ T *dstAddr = (__ubuf__ T *)dst.GetPhyAddr();
    __ubuf__ T *srcAddr = (__ubuf__ T *)src.GetPhyAddr();
    LocalTensor<T> tmpBuf = sharedTmpBuffer.ReinterpretCast<T>();
    __ubuf__ T *tmpAddr = (__ubuf__ T *)tmpBuf.GetPhyAddr();
    if constexpr (std::is_same_v<pattern, Pattern::Reduce::AR>) {
        if constexpr (SupportBytes<T, 8>()) {
            ReduceMaxARB64ReuseSourceCompute<T, isReuseSource>(dstAddr, srcAddr, tmpAddr, srcShape);
        } else {
            ReduceMaxARReuseSourceCompute<T, isReuseSource>(dstAddr, srcAddr, tmpAddr, srcShape);
        }
    } else {
        if constexpr (SupportBytes<T, 8>()) {
            ReduceMaxRAB64ReuseSourceCompute<T, isReuseSource>(dstAddr, srcAddr, tmpAddr, srcShape);
        } else {
            ReduceMaxRAReuseSourceCompute<T, isReuseSource>(dstAddr, srcAddr, tmpAddr, srcShape);
        }
    }
}
} // namespace Internal
} // namespace AscendC
#endif // IMPL_REDUCE_REDUCE_MAX_REDUCE_MAX_C310_IMPL_H_